{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3mqoATXZGTAX"
   },
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License.\n",
    "\n",
    "# Experiment Management with ClearML\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/experiment_management/unet_segmentation_3d_ignite_clearml.ipynb)\n",
    "\n",
    "This tutorial shows how to use ClearML to manage MONAI experiments. You can integrate ClearML into your code using MONAI's built-in handlers: `ClearMLImageHandler`, `ClearMLStatsHandler`, and `ModelCheckpoint`.\n",
    "\n",
    "The MONAI example used here is [3D segmentation with UNet](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unet_segmentation_3d_ignite.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Setup environment\n",
    "\n",
    "`clearml` comes as part of the `monai[all]` installation. For more information about integrating ClearML into your MONAI code, see [here](https://clear.ml/docs/latest/docs/integrations/monai). For more information about using ClearML (experiment management, data management, pipelines, model serving, and more), see [ClearML's official documentation](https://clear.ml/docs/latest/docs/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite, nibabel, tensorboard, clearml]\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G-L36u4sGTAg"
   },
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VAWF4EftGTAh"
   },
   "outputs": [],
   "source": [
    "import glob\n",
    "import logging\n",
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "import tempfile\n",
    "\n",
    "import nibabel as nib\n",
    "import numpy as np\n",
    "from monai.config import print_config\n",
    "from monai.data import ArrayDataset, create_test_image_3d, decollate_batch, DataLoader\n",
    "from monai.handlers import (\n",
    "    MeanDice,\n",
    "    StatsHandler,\n",
    ")\n",
    "\n",
    "# import the clearml handlers\n",
    "from monai.handlers.clearml_handlers import ClearMLImageHandler, ClearMLStatsHandler\n",
    "from monai.losses import DiceLoss\n",
    "from monai.networks.nets import UNet\n",
    "from monai.transforms import (\n",
    "    Activations,\n",
    "    EnsureChannelFirst,\n",
    "    AsDiscrete,\n",
    "    Compose,\n",
    "    LoadImage,\n",
    "    RandSpatialCrop,\n",
    "    Resize,\n",
    "    ScaleIntensity,\n",
    ")\n",
    "from monai.utils import first\n",
    "\n",
    "import ignite\n",
    "import torch\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Setup ClearML ⚓\n",
    "\n",
    "1. To keep track of your experiments and/or data, ClearML needs to communicate to a server. You have 2 server options:\n",
    "\n",
    "  * Sign up for free to the [ClearML Hosted Service](https://app.clear.ml/)\n",
    "  * Set up your own server, see [here](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server).\n",
    "\n",
    "2. Add your ClearML credentials below. ClearML will use these credentials to connect to your server (see instructions for generating credentials [here](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps/#jupyter-notebook))."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# clearml credentials\n",
    "%env CLEARML_WEB_HOST=''\n",
    "%env CLEARML_API_HOST=''\n",
    "%env CLEARML_FILES_HOST=''\n",
    "%env CLEARML_API_ACCESS_KEY=''\n",
    "%env CLEARML_API_SECRET_KEY=''"
   ]
  },
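  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check before training, the sketch below lists any of the five `CLEARML_*` variables from the cell above that are unset or still hold the empty `''` placeholder. The helper name `missing_clearml_vars` is hypothetical and not part of ClearML or MONAI; it only inspects environment variables and does not contact the server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# the five variables set in the credentials cell above\n",
    "CLEARML_VARS = (\n",
    "    \"CLEARML_WEB_HOST\",\n",
    "    \"CLEARML_API_HOST\",\n",
    "    \"CLEARML_FILES_HOST\",\n",
    "    \"CLEARML_API_ACCESS_KEY\",\n",
    "    \"CLEARML_API_SECRET_KEY\",\n",
    ")\n",
    "\n",
    "\n",
    "def missing_clearml_vars(env=None):\n",
    "    \"\"\"Return the variables that are unset, blank, or still the '' placeholder.\"\"\"\n",
    "    env = os.environ if env is None else env\n",
    "    # strip spaces and single quotes so a literal '' counts as empty\n",
    "    return [name for name in CLEARML_VARS if not env.get(name, \"\").strip(\" '\")]\n",
    "\n",
    "\n",
    "print(\"Missing ClearML settings:\", missing_clearml_vars() or \"none\")"
   ]
  },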
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cBN8rQJcGTAk"
   },
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  \n",
    "This allows you to save results and reuse downloads.  \n",
    "If not specified a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BSjqI-ZBGTAk"
   },
   "outputs": [],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "if directory is not None:\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LzxP-l1nGTAl"
   },
   "source": [
    "## Setup logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GGHooZZwGTAl"
   },
   "outputs": [],
   "source": [
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eIDTWNVxGTAm"
   },
   "source": [
    "## Setup demo data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l0jR_jz8GTAn"
   },
   "outputs": [],
   "source": [
    "for i in range(40):\n",
    "    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)\n",
    "\n",
    "    n = nib.Nifti1Image(im, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"im{i}.nii.gz\"))\n",
    "\n",
    "    n = nib.Nifti1Image(seg, np.eye(4))\n",
    "    nib.save(n, os.path.join(root_dir, f\"seg{i}.nii.gz\"))\n",
    "\n",
    "images = sorted(glob.glob(os.path.join(root_dir, \"im*.nii.gz\")))\n",
    "segs = sorted(glob.glob(os.path.join(root_dir, \"seg*.nii.gz\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8FUIYNO0GTAn"
   },
   "source": [
    "## Setup transforms, dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "g5gz3rECGTAo"
   },
   "outputs": [],
   "source": [
    "# Define transforms for image and segmentation\n",
    "imtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        ScaleIntensity(),\n",
    "        EnsureChannelFirst(),\n",
    "        RandSpatialCrop((96, 96, 96), random_size=False),\n",
    "    ]\n",
    ")\n",
    "segtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        EnsureChannelFirst(),\n",
    "        RandSpatialCrop((96, 96, 96), random_size=False),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Define nifti dataset, dataloader\n",
    "ds = ArrayDataset(images, imtrans, segs, segtrans)\n",
    "# loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())\n",
    "loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=False)\n",
    "im, seg = first(loader)\n",
    "print(im.shape, seg.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6zl4_zeSGTAp"
   },
   "source": [
    "## Create Model, Loss, Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XT1D-a1-GTAp"
   },
   "outputs": [],
   "source": [
    "# Create UNet, DiceLoss and Adam optimizer\n",
    "\n",
    "device = None  # torch.device(\"cuda:0\")\n",
    "net = UNet(\n",
    "    spatial_dims=3,\n",
    "    in_channels=1,\n",
    "    out_channels=1,\n",
    "    channels=(16, 32, 64, 128, 256),\n",
    "    strides=(2, 2, 2, 2),\n",
    "    num_res_units=2,\n",
    ").to(device)\n",
    "\n",
    "loss = DiceLoss(sigmoid=True)\n",
    "lr = 1e-3\n",
    "opt = torch.optim.Adam(net.parameters(), lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xvduwsk8GTBm"
   },
   "source": [
    "## Create supervised_trainer using ignite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9Ocpt_MOGTBn"
   },
   "outputs": [],
   "source": [
    "# Create trainer\n",
    "trainer = ignite.engine.create_supervised_trainer(net, opt, loss, device, False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5vo9Eq33GTBn"
   },
   "source": [
    "## Set up event handlers for checkpointing and logging\n",
    "\n",
    "Using a ClearML handler creates a ClearML Task, which captures your experiment's models, scalars, images, and more.\n",
    "\n",
    "The console output displays the task ID and a link to the task's page in the ClearML web UI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cZTtFnySGTBo"
   },
   "outputs": [],
   "source": [
    "# optional section for checkpoint and clearml logging\n",
    "# adding checkpoint handler to save models (network\n",
    "# params and optimizer stats) during training\n",
    "log_dir = os.path.join(root_dir, \"logs\")\n",
    "checkpoint_handler = ignite.handlers.ModelCheckpoint(log_dir, \"net\", n_saved=10, require_empty=False)\n",
    "trainer.add_event_handler(\n",
    "    event_name=ignite.engine.Events.EPOCH_COMPLETED,\n",
    "    handler=checkpoint_handler,\n",
    "    to_save={\"net\": net, \"opt\": opt},\n",
    ")\n",
    "\n",
    "# StatsHandler prints loss at every iteration\n",
    "# user can also customize print functions and can use output_transform to convert\n",
    "# engine.state.output if it's not a loss value\n",
    "train_stats_handler = StatsHandler(name=\"trainer\", output_transform=lambda x: x)\n",
    "train_stats_handler.attach(trainer)\n",
    "\n",
    "\n",
    "# ClearMLStatsHandler plots loss at every iteration\n",
    "# Creates a ClearML Task which logs the scalar plots\n",
    "task_name = \"UNet segmentation 3d\"\n",
    "project_name = \"MONAI example\"\n",
    "\n",
    "train_clearml_stats_handler = ClearMLStatsHandler(\n",
    "    task_name=task_name, project_name=project_name, log_dir=log_dir, output_transform=lambda x: x\n",
    ")\n",
    "train_clearml_stats_handler.attach(trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hBv81ENGGTBp"
   },
   "source": [
    "## Add Validation every N epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VYte_Cw6GTBp"
   },
   "outputs": [],
   "source": [
    "# optional section for model validation during training\n",
    "validation_every_n_epochs = 1\n",
    "# Set parameters for validation\n",
    "metric_name = \"Mean_Dice\"\n",
    "# add evaluation metric to the evaluator engine\n",
    "val_metrics = {metric_name: MeanDice()}\n",
    "post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])\n",
    "post_label = Compose([AsDiscrete(threshold=0.5)])\n",
    "# Ignite evaluator expects batch=(img, seg) and\n",
    "# returns output=(y_pred, y) at every iteration,\n",
    "# user can add output_transform to return other values\n",
    "evaluator = ignite.engine.create_supervised_evaluator(\n",
    "    net,\n",
    "    val_metrics,\n",
    "    device,\n",
    "    True,\n",
    "    output_transform=lambda x, y, y_pred: (\n",
    "        [post_pred(i) for i in decollate_batch(y_pred)],\n",
    "        [post_label(i) for i in decollate_batch(y)],\n",
    "    ),\n",
    ")\n",
    "\n",
    "# create a validation data loader\n",
    "val_imtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        ScaleIntensity(),\n",
    "        EnsureChannelFirst(),\n",
    "        Resize((96, 96, 96)),\n",
    "    ]\n",
    ")\n",
    "val_segtrans = Compose(\n",
    "    [\n",
    "        LoadImage(image_only=True),\n",
    "        EnsureChannelFirst(),\n",
    "        Resize((96, 96, 96)),\n",
    "    ]\n",
    ")\n",
    "val_ds = ArrayDataset(images[21:], val_imtrans, segs[21:], val_segtrans)\n",
    "# val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())\n",
    "val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=False)\n",
    "\n",
    "\n",
    "@trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=validation_every_n_epochs))\n",
    "def run_validation(engine):\n",
    "    evaluator.run(val_loader)\n",
    "\n",
    "\n",
    "# Add stats event handler to print validation stats via evaluator\n",
    "val_stats_handler = StatsHandler(\n",
    "    name=\"evaluator\",\n",
    "    # no need to print loss value, so disable per iteration output\n",
    "    output_transform=lambda x: None,\n",
    "    # fetch global epoch number from trainer\n",
    "    global_epoch_transform=lambda x: trainer.state.epoch,\n",
    ")\n",
    "val_stats_handler.attach(evaluator)\n",
    "\n",
    "# add handler to record metrics to clearml at every validation epoch\n",
    "val_clearml_stats_handler = ClearMLStatsHandler(\n",
    "    log_dir=log_dir,\n",
    "    # no need to plot loss value, so disable per iteration output\n",
    "    output_transform=lambda x: None,\n",
    "    # fetch global epoch number from trainer\n",
    "    global_epoch_transform=lambda x: trainer.state.epoch,\n",
    ")\n",
    "val_clearml_stats_handler.attach(evaluator)\n",
    "\n",
    "# add handler to draw the first image and the corresponding\n",
    "# label and model output in the last batch\n",
    "# here we draw the 3D output as GIF format along Depth\n",
    "# axis, at every validation epoch\n",
    "val_clearml_image_handler = ClearMLImageHandler(\n",
    "    task_name=task_name,\n",
    "    project_name=project_name,\n",
    "    log_dir=log_dir,\n",
    "    batch_transform=lambda batch: (batch[0], batch[1]),\n",
    "    output_transform=lambda output: output[0],\n",
    "    global_iter_transform=lambda x: trainer.state.epoch,\n",
    ")\n",
    "evaluator.add_event_handler(\n",
    "    event_name=ignite.engine.Events.EPOCH_COMPLETED,\n",
    "    handler=val_clearml_image_handler,\n",
    ")"
   ]
  },
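  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `Mean_Dice` metric concrete, here is a minimal NumPy sketch of the Dice score on a toy example. It mirrors the `post_pred` step above (sigmoid, then a 0.5 threshold) but is not MONAI's implementation: `dice_score` is a hypothetical helper, the input values are made up, and returning NaN for two empty masks is an arbitrary choice for this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def dice_score(pred, label):\n",
    "    \"\"\"Dice = 2 * |pred AND label| / (|pred| + |label|) for binary masks.\"\"\"\n",
    "    pred = np.asarray(pred, dtype=bool)\n",
    "    label = np.asarray(label, dtype=bool)\n",
    "    denominator = pred.sum() + label.sum()\n",
    "    if denominator == 0:\n",
    "        # both masks are empty; this sketch reports NaN (an assumption)\n",
    "        return float(\"nan\")\n",
    "    intersection = np.logical_and(pred, label).sum()\n",
    "    return 2.0 * intersection / denominator\n",
    "\n",
    "\n",
    "# toy logits: apply sigmoid, then threshold at 0.5\n",
    "logits = np.array([2.0, -1.0, 0.5, -3.0])\n",
    "probs = 1.0 / (1.0 + np.exp(-logits))\n",
    "pred_mask = probs >= 0.5  # [True, False, True, False]\n",
    "label_mask = np.array([1, 1, 0, 0], dtype=bool)\n",
    "\n",
    "print(dice_score(pred_mask, label_mask))  # 2 * 1 / (2 + 2) = 0.5"
   ]
  },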
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VxXixOfHGTBq"
   },
   "source": [
    "## Run training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kW0K4mxIGTBt"
   },
   "outputs": [],
   "source": [
    "# create a training data loader\n",
    "train_ds = ArrayDataset(images[:20], imtrans, segs[:20], segtrans)\n",
    "train_loader = DataLoader(\n",
    "    train_ds,\n",
    "    batch_size=5,\n",
    "    shuffle=True,\n",
    "    num_workers=8,\n",
    "    # pin_memory=torch.cuda.is_available(),\n",
    "    pin_memory=False,\n",
    ")\n",
    "\n",
    "max_epochs = 10\n",
    "trainer.run(train_loader, max_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7bPDjtbE7sCP"
   },
   "source": [
    "## Visualize results\n",
    "\n",
    "ClearML captures the models, scalar plots, and images logged with `ModelCheckpoint`, `ClearMLImageHandler`, and `ClearMLStatsHandler` respectively. View them in ClearML's web UI. When a task is created, the console output displays the task ID and a link to the task's page in the ClearML web UI.\n",
    "\n",
    "### Models\n",
    "All model checkpoints logged with ModelCheckpoint can be viewed in the task's **Artifacts** tab.\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Uya6hN7At-C6"
   },
   "source": [
    "![MONAI ClearML Models](./../figures/monai_clearml_models.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kLs1tEXRt-Pp"
   },
   "source": [
    "### Scalars\n",
    "\n",
    "View the logged metric plots in the task's **Scalars** tab.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DCLhp4e2uKsd"
   },
   "source": [
    "![MONAI ClearML scalars.png](./../figures/monai_clearml_scalars.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yjfeorgJuLUN"
   },
   "source": [
    "### Debug Samples\n",
    "\n",
    "View all images logged through ClearMLImageHandler in the task's **Debug Samples** tab. You can view the samples by metric at any iteration."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AFLJauKxt-Yy"
   },
   "source": [
    "![MONAI ClearML Debug Samples.png](./../figures/monai_clearml_debug_samples.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UFlaQUTXGTBv"
   },
   "source": [
    "## Cleanup data directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "w6JE9lVQGTBw"
   },
   "outputs": [],
   "source": [
    "if directory is None:\n",
    "    shutil.rmtree(root_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sT-5Zxg_7Ajm"
   },
   "source": [
    "\n",
    "## Close the ClearML Task\n",
    "This changes the task status to `Completed`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BAD6ACGGOLU3"
   },
   "outputs": [],
   "source": [
    "val_clearml_image_handler.clearml_task.close()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
